{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyNNUL2WHcN0SX5H8gm9pLZ2",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/thu-ml/unidiffuser/blob/main/UniDiffuser.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# [One Transformer Fits All Distributions in Multi-Modal Diffusion at Scale](https://arxiv.org/abs/2303.06555)\n",
        "\n",
        "This is a demo for sampling from [UniDiffuser](https://arxiv.org/abs/2303.06555). UniDiffuser is a unified diffusion framework that fits all distributions relevant to a set of multi-modal data in one model. Implemented on large-scale paired image-text data, UniDiffuser is able to perform image, text, text-to-image, image-to-text, and image-text pair generation.\n",
        "\n",
        "[Paper](https://arxiv.org/abs/2303.06555) | [GitHub](https://github.com/thu-ml/unidiffuser)"
      ],
      "metadata": {
        "id": "-vRJ-KSH5334"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Dependency and Pretrained Models"
      ],
      "metadata": {
        "id": "vEEqy9bRPYKY"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Clone the UniDiffuser and CLIP repositories and install the dependencies"
      ],
      "metadata": {
        "id": "rgAzXSsIA_wS"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Clone the UniDiffuser code and OpenAI's CLIP into the current working directory.\n",
        "!git clone https://github.com/thu-ml/unidiffuser.git\n",
        "!git clone https://github.com/openai/CLIP.git\n",
        "\n",
        "# Install CLIP in editable mode from its clone, then the remaining Python packages.\n",
        "!pip install -e ./CLIP\n",
        "!pip install accelerate==0.12.0 absl-py ml_collections einops ftfy==6.1.1 transformers==4.23.1\n",
        "\n",
        "# NOTE(review): xformers and triton are upgraded without a version pin, so the\n",
        "# versions that get installed depend on when this cell is run.\n",
        "!pip install -U xformers\n",
        "!pip install -U --pre triton\n",
        "\n",
        "# Make the current directory and the CLIP clone importable from this kernel.\n",
        "# NOTE(review): '/content/CLIP' assumes the clone above ran in /content -- confirm\n",
        "# when running outside Colab.\n",
        "import sys\n",
        "sys.path.append(\".\")\n",
        "sys.path.append('/content/CLIP')"
      ],
      "metadata": {
        "id": "_txRP2VTAy3P"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Download pretrained models from HuggingFace"
      ],
      "metadata": {
        "id": "YV-uS9KaJ8sw"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "\n",
        "# Work from inside the cloned repository so that relative paths such as\n",
        "# 'models/...' used by the following cells resolve.\n",
        "os.chdir('/content/unidiffuser')\n",
        "\n",
        "# -p keeps the cell re-runnable: no error when 'models' already exists.\n",
        "!mkdir -p models\n",
        "%cd models\n",
        "# wget -c resumes a partially downloaded checkpoint instead of starting over.\n",
        "!wget -c https://huggingface.co/thu-ml/unidiffuser-v1/resolve/main/autoencoder_kl.pth\n",
        "!wget -c https://huggingface.co/thu-ml/unidiffuser-v1/resolve/main/caption_decoder.pth\n",
        "!wget -c https://huggingface.co/thu-ml/unidiffuser-v1/resolve/main/uvit_v1.pth\n",
        "%cd ..\n"
      ],
      "metadata": {
        "id": "p7TqlfbVKFjC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's also check what type of GPU we've got."
      ],
      "metadata": {
        "id": "zVXoPmHicHLZ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Print the NVIDIA GPU(s) visible to this runtime.\n",
        "!nvidia-smi"
      ],
      "metadata": {
        "id": "fH4JqaIxVe1R"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Import what we need"
      ],
      "metadata": {
        "id": "hUTfjc225L1o"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import ml_collections\n",
        "import torch\n",
        "import random\n",
        "import utils\n",
        "from dpm_solver_pp import NoiseScheduleVP, DPM_Solver\n",
        "from absl import logging\n",
        "import einops\n",
        "import libs.autoencoder\n",
        "import libs.clip\n",
        "from torchvision.utils import save_image, make_grid\n",
        "import torchvision.transforms as standard_transforms\n",
        "import numpy as np\n",
        "import clip\n",
        "from PIL import Image"
      ],
      "metadata": {
        "id": "wE_YD_A_5OAq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Load models"
      ],
      "metadata": {
        "id": "yms6SlIC3qnK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from libs.uvit_multi_post_ln_v1 import UViT\n",
        "\n",
        "# Run on the GPU when one is available; every model below is placed on this device.\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "# The sizes below match the tensors handled by the helper functions in this notebook:\n",
        "# 4x64x64 image latents, 77x64 text embeddings and a 512-dim CLIP image embedding.\n",
        "nnet = UViT(\n",
        "    img_size=64,\n",
        "    in_chans=4,\n",
        "    patch_size=2,\n",
        "    embed_dim=1536,\n",
        "    depth=30,\n",
        "    num_heads=24,\n",
        "    text_dim=64,\n",
        "    num_text_tokens=77,\n",
        "    clip_img_dim=512,\n",
        "    use_checkpoint=True\n",
        ")\n",
        "nnet.to(device)\n",
        "# The checkpoint is deserialized on the CPU and then loaded into the module.\n",
        "nnet.load_state_dict(torch.load('models/uvit_v1.pth', map_location='cpu'))\n",
        "nnet.eval()\n",
        "\n",
        "\n",
        "# Caption decoder: used later through encode_prefix() and generate_captions().\n",
        "from libs.caption_decoder import CaptionDecoder\n",
        "caption_decoder = CaptionDecoder(device=device, pretrained_path=\"models/caption_decoder.pth\", hidden_dim=64)\n",
        "\n",
        "# CLIP text encoder: used later through encode() to embed the prompt.\n",
        "clip_text_model = libs.clip.FrozenCLIPEmbedder(device=device)\n",
        "clip_text_model.eval()\n",
        "clip_text_model.to(device)\n",
        "\n",
        "# Image autoencoder: used later through encode_moments(), sample() and decode().\n",
        "# NOTE(review): eval() is not called on it here -- confirm get_model() returns it in eval mode.\n",
        "autoencoder = libs.autoencoder.get_model(pretrained_path='models/autoencoder_kl.pth')\n",
        "autoencoder.to(device)\n",
        "\n",
        "# CLIP image encoder and its matching preprocessing transform: used later through encode_image().\n",
        "clip_img_model, clip_img_model_preprocess = clip.load(\"ViT-B/32\", device=device, jit=False)\n",
        "\n",
        "@torch.cuda.amp.autocast()\n",
        "def encode(_batch):\n",
        "    \"\"\"Run `autoencoder.encode` on `_batch` with CUDA autocast enabled.\"\"\"\n",
        "    encoded = autoencoder.encode(_batch)\n",
        "    return encoded\n",
        "\n",
        "@torch.cuda.amp.autocast()\n",
        "def decode(_batch):\n",
        "    \"\"\"Run `autoencoder.decode` on `_batch` with CUDA autocast enabled.\"\"\"\n",
        "    decoded = autoencoder.decode(_batch)\n",
        "    return decoded\n"
      ],
      "metadata": {
        "id": "JGDO6IUI3qR2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Prepare"
      ],
      "metadata": {
        "id": "lr5JaWajQD3O"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Define the required helper functions"
      ],
      "metadata": {
        "id": "Jia6mVVDGjYA"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):\n",
        "    \"\"\"Return the beta schedule as a float64 NumPy array with `n_timestep` entries.\n",
        "\n",
        "    The square roots of the betas are spaced linearly between\n",
        "    sqrt(linear_start) and sqrt(linear_end); squaring them yields the betas.\n",
        "    \"\"\"\n",
        "    sqrt_first = linear_start ** 0.5\n",
        "    sqrt_last = linear_end ** 0.5\n",
        "    sqrt_betas = torch.linspace(sqrt_first, sqrt_last, n_timestep, dtype=torch.float64)\n",
        "    return sqrt_betas.pow(2).numpy()\n",
        "# Module-level beta schedule shared by the sampling code in this notebook.\n",
        "_betas = stable_diffusion_beta_schedule()\n",
        "# N is the number of timesteps in the schedule (1000 with the default n_timestep).\n",
        "N = len(_betas)\n",
        "\n",
        "def split(x):\n",
        "    \"\"\"Undo `combine`: split flattened rows into `z` and `clip_img`.\n",
        "\n",
        "    `x` is cut along dim 1 into 4*64*64 latent values followed by 512 CLIP image\n",
        "    values. Returns `z` shaped (B, 4, 64, 64) and `clip_img` shaped (B, 1, 512).\n",
        "    \"\"\"\n",
        "    latent_dims = dict(C=4, H=64, W=64)\n",
        "    n_latent = latent_dims['C'] * latent_dims['H'] * latent_dims['W']\n",
        "    z_flat, clip_img_flat = x.split([n_latent, 512], dim=1)\n",
        "    z = einops.rearrange(z_flat, 'B (C H W) -> B C H W', **latent_dims)\n",
        "    clip_img = einops.rearrange(clip_img_flat, 'B (L D) -> B L D', L=1, D=512)\n",
        "    return z, clip_img\n",
        "\n",
        "def combine(z, clip_img):\n",
        "    \"\"\"Flatten `z` and `clip_img` per sample and concatenate them along the last dim.\"\"\"\n",
        "    flat_parts = [\n",
        "        einops.rearrange(z, 'B C H W -> B (C H W)'),\n",
        "        einops.rearrange(clip_img, 'B L D -> B (L D)'),\n",
        "    ]\n",
        "    return torch.concat(flat_parts, dim=-1)\n",
        "\n",
        "def combine_joint(z, clip_img, text):\n",
        "    \"\"\"Flatten `z`, `clip_img` and `text` per sample and concatenate them along the last dim.\"\"\"\n",
        "    flat_parts = [\n",
        "        einops.rearrange(z, 'B C H W -> B (C H W)'),\n",
        "        einops.rearrange(clip_img, 'B L D -> B (L D)'),\n",
        "        einops.rearrange(text, 'B L D -> B (L D)'),\n",
        "    ]\n",
        "    return torch.concat(flat_parts, dim=-1)\n",
        "\n",
        "def split_joint(x):\n",
        "    \"\"\"Undo `combine_joint`: split flattened rows into `z`, `clip_img` and `text`.\n",
        "\n",
        "    `x` is cut along dim 1 into 4*64*64 latent values, 512 CLIP image values and\n",
        "    77*64 text values. Returns `z` shaped (B, 4, 64, 64), `clip_img` shaped\n",
        "    (B, 1, 512) and `text` shaped (B, 77, 64).\n",
        "    \"\"\"\n",
        "    latent_dims = dict(C=4, H=64, W=64)\n",
        "    n_latent = latent_dims['C'] * latent_dims['H'] * latent_dims['W']\n",
        "    z_flat, clip_img_flat, text_flat = x.split([n_latent, 512, 77 * 64], dim=1)\n",
        "    z = einops.rearrange(z_flat, 'B (C H W) -> B C H W', **latent_dims)\n",
        "    clip_img = einops.rearrange(clip_img_flat, 'B (L D) -> B L D', L=1, D=512)\n",
        "    text = einops.rearrange(text_flat, 'B (L D) -> B L D', L=77, D=64)\n",
        "    return z, clip_img, text\n",
        "\n",
        "def unpreprocess(v):\n",
        "    \"\"\"Rescale `v` with `0.5 * (v + 1)` and clamp the result to [0, 1].\n",
        "\n",
        "    A new tensor is returned; the argument itself is not modified.\n",
        "    \"\"\"\n",
        "    return (0.5 * (v + 1.)).clamp(0., 1.)\n",
        "\n",
        "\n",
        "def set_seed(seed: int):\n",
        "    \"\"\"Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices) with `seed`.\"\"\"\n",
        "    seeders = (\n",
        "        random.seed,\n",
        "        np.random.seed,\n",
        "        torch.manual_seed,\n",
        "        torch.cuda.manual_seed_all,\n",
        "    )\n",
        "    for seeder in seeders:\n",
        "        seeder(seed)\n",
        "\n",
        "def watermarking(save_path):\n",
        "    \"\"\"Watermark the image stored at `save_path`, overwriting the file in place.\n",
        "\n",
        "    The image is opened with PIL, passed through `utils.add_water`, and the\n",
        "    returned image is saved back to the same path.\n",
        "    \"\"\"\n",
        "    # The context manager releases the file handle kept open by Image.open once\n",
        "    # the watermarked copy has been written, instead of leaving it to the GC.\n",
        "    with Image.open(save_path) as img_pre:\n",
        "        img_pos = utils.add_water(img_pre)\n",
        "        img_pos.save(save_path)"
      ],
      "metadata": {
        "id": "w9Ui064gGnlM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Sample from UniDiffuser"
      ],
      "metadata": {
        "id": "dDUqUaMMQL-0"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Define hyperparameters"
      ],
      "metadata": {
        "id": "2Z1x1Wbu-AtC"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Sampling task to run; the string below lists the accepted values.\n",
        "mode = \"t2i\" #@param {type:\"string\"}\n",
        "\"\"\"\n",
        "t2i: text-to-image generation\n",
        "i2t: image-to-text generation\n",
        "joint: joint generation\n",
        "i: image generation\n",
        "t: text generation\n",
        "t2i2t: text variation\n",
        "i2t2i: image variation\n",
        "\"\"\"\n",
        "assert mode in ['t2i', 'i2t', 'joint', 't', 'i', 't2i2t', 'i2t2i']\n",
        "# Text condition; encoded only for the 't2i' and 't2i2t' modes.\n",
        "prompt = \"an elephant under the sea\" #@param {type:\"string\"}\n",
        "# Path of the conditioning image; read only for the 'i2t' and 'i2t2i' modes.\n",
        "img = 'assets/space.jpg' #@param {type:\"string\"}\n",
        "# Seed passed to set_seed() before sampling.\n",
        "seed = 1234 #@param {type:\"number\"}\n",
        "# Number of solver steps passed to dpm_solver.sample().\n",
        "steps = 50 #@param {type:\"slider\", min:0, max:100, step:1}\n",
        "# Guidance scale; when it is 0 the *_nnet wrappers return the conditional output only.\n",
        "cfg_scale = 8 #@param {type:\"slider\", min:0, max:10, step:0.1}\n",
        "# Number of samples generated per run (batch size of the initial noise).\n",
        "n_samples = 4 #@param {type:\"number\"}\n",
        "# Passed to make_grid() as its second positional argument when the image grid is built.\n",
        "nrow = 2 #@param {type:\"number\"}\n",
        "# Integer forwarded to every nnet call as a per-sample `data_type` tensor.\n",
        "# NOTE(review): its meaning is defined by the UViT model, not by this notebook.\n",
        "data_type = 1\n",
        "# Directory (relative to the working directory) where results are written.\n",
        "output_path = 'out'\n",
        "\n",
        "# Pre-compute the conditioning inputs required by the selected mode.\n",
        "# Only inference is performed, so the encoders run under torch.no_grad() to avoid\n",
        "# building autograd graphs that would be kept alive together with the results.\n",
        "if mode == 't2i' or mode == 't2i2t':\n",
        "  prompts = [ prompt ] * n_samples\n",
        "  with torch.no_grad():\n",
        "    contexts = clip_text_model.encode(prompts)\n",
        "    contexts_low_dim = caption_decoder.encode_prefix(contexts)\n",
        "elif mode == 'i2t' or mode == 'i2t2i':\n",
        "\n",
        "  def get_img_feature(image):\n",
        "      \"\"\"Return `(clip_img_feature, moments)` for one RGB PIL image.\n",
        "\n",
        "      The image is cropped with `utils.center_crop(512, 512, ...)` and encoded by the\n",
        "      CLIP image model; the cropped pixels are then rescaled from [0, 255] to [-1, 1]\n",
        "      and passed to `autoencoder.encode_moments`.\n",
        "      \"\"\"\n",
        "      image = np.array(image).astype(np.uint8)\n",
        "      image = utils.center_crop(512, 512, image)\n",
        "      clip_img_feature = clip_img_model.encode_image(clip_img_model_preprocess(Image.fromarray(image)).unsqueeze(0).to(device))\n",
        "\n",
        "      image = (image / 127.5 - 1.0).astype(np.float32)\n",
        "      image = einops.rearrange(image, 'h w c -> 1 c h w')\n",
        "      image = torch.tensor(image, device=device)\n",
        "      moments = autoencoder.encode_moments(image)\n",
        "\n",
        "      return clip_img_feature, moments\n",
        "\n",
        "  # Release the file handle as soon as the pixels have been converted to RGB.\n",
        "  with Image.open(img) as img_file:\n",
        "    image = img_file.convert('RGB')\n",
        "\n",
        "  # NOTE(review): autoencoder.sample() presumably draws random noise; seeding here makes\n",
        "  # z_img repeatable for a given `seed` -- confirm against libs.autoencoder.\n",
        "  set_seed(seed)\n",
        "  with torch.no_grad():\n",
        "    clip_img, img_context = get_img_feature(image)\n",
        "\n",
        "    # Every sample is conditioned on the same image, so its features are repeated n_samples times.\n",
        "    img_contexts = torch.concat([img_context] * n_samples, dim=0)\n",
        "    z_img = autoencoder.sample(img_contexts)\n",
        "    clip_imgs = torch.stack([clip_img] * n_samples, dim=0)\n"
      ],
      "metadata": {
        "id": "ZptJaHbG65zW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Define the required functions"
      ],
      "metadata": {
        "id": "r6optBAy9U12"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def t2i_nnet(x, timesteps, text):  # text is the low dimension version of the text clip embedding\n",
        "    \"\"\"Text-conditioned prediction for the image state `x`, with classifier-free guidance.\n",
        "\n",
        "    1. Conditional output: run the model with the given `text` at text timestep 0.\n",
        "    2. Unconditional output: run it again with Gaussian-noise text at text timestep N.\n",
        "    3. Return `cond + cfg_scale * (cond - uncond)`; step 2 is skipped when `cfg_scale` is 0.\n",
        "    \"\"\"\n",
        "    z, clip_img = split(x)\n",
        "\n",
        "    clean_text_t = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)\n",
        "    data_type_ids = torch.zeros_like(clean_text_t, device=device, dtype=torch.int) + data_type\n",
        "\n",
        "    cond_z, cond_clip_img, _ = nnet(z, clip_img, text=text, t_img=timesteps, t_text=clean_text_t,\n",
        "                                    data_type=data_type_ids)\n",
        "    cond = combine(cond_z, cond_clip_img)\n",
        "    if cfg_scale == 0.:\n",
        "        return cond\n",
        "\n",
        "    noise_text = torch.randn_like(text)\n",
        "    noisy_text_t = torch.ones_like(timesteps) * N\n",
        "    uncond_z, uncond_clip_img, _ = nnet(z, clip_img, text=noise_text, t_img=timesteps, t_text=noisy_text_t,\n",
        "                                        data_type=data_type_ids)\n",
        "    uncond = combine(uncond_z, uncond_clip_img)\n",
        "\n",
        "    return cond + cfg_scale * (cond - uncond)\n",
        "\n",
        "def i_nnet(x, timesteps):\n",
        "    \"\"\"Unconditional image prediction: the text input is Gaussian noise at text timestep N.\"\"\"\n",
        "    z, clip_img = split(x)\n",
        "    noise_text = torch.randn(x.size(0), 77, 64, device=device)\n",
        "    noisy_text_t = torch.ones_like(timesteps) * N\n",
        "    data_type_ids = torch.zeros_like(noisy_text_t, device=device, dtype=torch.int) + data_type\n",
        "    z_pred, clip_img_pred, _ = nnet(z, clip_img, text=noise_text, t_img=timesteps, t_text=noisy_text_t,\n",
        "                                    data_type=data_type_ids)\n",
        "    return combine(z_pred, clip_img_pred)\n",
        "\n",
        "def t_nnet(x, timesteps):\n",
        "    \"\"\"Unconditional text prediction: the image inputs are Gaussian noise at image timestep N.\"\"\"\n",
        "    batch = x.size(0)\n",
        "    noise_z = torch.randn(batch, 4, 64, 64, device=device)\n",
        "    noise_clip_img = torch.randn(batch, 1, 512, device=device)\n",
        "    noisy_img_t = torch.ones_like(timesteps) * N\n",
        "    data_type_ids = torch.zeros_like(timesteps, device=device, dtype=torch.int) + data_type\n",
        "    _, _, text_pred = nnet(noise_z, noise_clip_img, text=x, t_img=noisy_img_t, t_text=timesteps,\n",
        "                           data_type=data_type_ids)\n",
        "    return text_pred\n",
        "\n",
        "def i2t_nnet(x, timesteps, z, clip_img):\n",
        "    \"\"\"Image-conditioned prediction for the text state `x`, with classifier-free guidance.\n",
        "\n",
        "    1. Conditional output: run the model with the given `z` / `clip_img` at image timestep 0.\n",
        "    2. Unconditional output: run it again with Gaussian-noise image inputs at image timestep N.\n",
        "    3. Return `cond + cfg_scale * (cond - uncond)`; step 2 is skipped when `cfg_scale` is 0.\n",
        "    \"\"\"\n",
        "    clean_img_t = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)\n",
        "    cond_data_type = torch.zeros_like(clean_img_t, device=device, dtype=torch.int) + data_type\n",
        "\n",
        "    _, _, text_cond = nnet(z, clip_img, text=x, t_img=clean_img_t, t_text=timesteps,\n",
        "                           data_type=cond_data_type)\n",
        "    if cfg_scale == 0.:\n",
        "        return text_cond\n",
        "\n",
        "    noise_z = torch.randn_like(z)\n",
        "    noise_clip_img = torch.randn_like(clip_img)\n",
        "    noisy_img_t = torch.ones_like(timesteps) * N\n",
        "    uncond_data_type = torch.zeros_like(timesteps, device=device, dtype=torch.int) + data_type\n",
        "    _, _, text_uncond = nnet(noise_z, noise_clip_img, text=x, t_img=noisy_img_t, t_text=timesteps,\n",
        "                             data_type=uncond_data_type)\n",
        "\n",
        "    return text_cond + cfg_scale * (text_cond - text_uncond)\n",
        "\n",
        "def joint_nnet(x, timesteps):\n",
        "    \"\"\"Joint image-text prediction with classifier-free guidance.\n",
        "\n",
        "    The conditional output runs the model on the image and text states at the same\n",
        "    timestep. The unconditional output is assembled from two further passes: the text\n",
        "    prediction with the image inputs replaced by noise at timestep N, and the image\n",
        "    predictions with the text replaced by noise at timestep N. The result is\n",
        "    `cond + cfg_scale * (cond - uncond)`; the extra passes are skipped when `cfg_scale` is 0.\n",
        "    \"\"\"\n",
        "    z, clip_img, text = split_joint(x)\n",
        "    data_type_ids = torch.zeros_like(timesteps, device=device, dtype=torch.int) + data_type\n",
        "\n",
        "    cond_z, cond_clip_img, cond_text = nnet(z, clip_img, text=text, t_img=timesteps, t_text=timesteps,\n",
        "                                            data_type=data_type_ids)\n",
        "    cond = combine_joint(cond_z, cond_clip_img, cond_text)\n",
        "    if cfg_scale == 0.:\n",
        "        return cond\n",
        "\n",
        "    batch = x.size(0)\n",
        "    noise_z = torch.randn(batch, 4, 64, 64, device=device)\n",
        "    noise_clip_img = torch.randn(batch, 1, 512, device=device)\n",
        "    noise_text = torch.randn(batch, 77, 64, device=device)\n",
        "    max_t = torch.ones_like(timesteps) * N\n",
        "\n",
        "    # Text prediction without image information.\n",
        "    _, _, uncond_text = nnet(noise_z, noise_clip_img, text=text, t_img=max_t, t_text=timesteps,\n",
        "                             data_type=data_type_ids)\n",
        "    # Image predictions without text information.\n",
        "    uncond_z, uncond_clip_img, _ = nnet(z, clip_img, text=noise_text, t_img=timesteps, t_text=max_t,\n",
        "                                        data_type=data_type_ids)\n",
        "    uncond = combine_joint(uncond_z, uncond_clip_img, uncond_text)\n",
        "\n",
        "    return cond + cfg_scale * (cond - uncond)\n",
        "\n",
        "\n",
        "def sample_fn(mode, **kwargs):\n",
        "    \"\"\"Draw `n_samples` samples with DPM-Solver for one of the basic modes.\n",
        "\n",
        "    Args:\n",
        "        mode: one of 'joint', 't2i', 'i2t', 'i', 't'.\n",
        "        **kwargs: conditioning forwarded to the model wrapper -- `text=...` for 't2i',\n",
        "            `z=...` and `clip_img=...` for 'i2t'; not used by the other modes.\n",
        "\n",
        "    Returns:\n",
        "        'joint': `(z, clip_img, text)`; 't2i' / 'i': `(z, clip_img)`; 'i2t' / 't': the text tensor.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if `mode` is not one of the supported modes.\n",
        "    \"\"\"\n",
        "    # Fail early: an unsupported mode used to fall through every branch and only\n",
        "    # surface later as an UnboundLocalError on `_x_init`.\n",
        "    supported_modes = ('joint', 't2i', 'i2t', 'i', 't')\n",
        "    if mode not in supported_modes:\n",
        "        raise ValueError(f'unsupported sampling mode {mode!r}; expected one of {supported_modes}')\n",
        "\n",
        "    # All three noise tensors are drawn for every mode, in this order, so the random\n",
        "    # stream consumed for a given seed stays the same as before.\n",
        "    _z_init = torch.randn(n_samples, *(4, 64, 64), device=device)\n",
        "    _clip_img_init = torch.randn(n_samples, 1, 512, device=device)\n",
        "    _text_init = torch.randn(n_samples, 77, 64, device=device)\n",
        "    if mode == 'joint':\n",
        "        _x_init = combine_joint(_z_init, _clip_img_init, _text_init)\n",
        "    elif mode in ['t2i', 'i']:\n",
        "        _x_init = combine(_z_init, _clip_img_init)\n",
        "    else:  # 'i2t' or 't'\n",
        "        _x_init = _text_init\n",
        "    noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())\n",
        "\n",
        "    def model_fn(x, t_continuous):\n",
        "        # The solver's time runs from 1/N to 1 (see eps and T below); the wrappers take it scaled by N.\n",
        "        t = t_continuous * N\n",
        "        if mode == 'joint':\n",
        "            return joint_nnet(x, t)\n",
        "        elif mode == 't2i':\n",
        "            return t2i_nnet(x, t, **kwargs)\n",
        "        elif mode == 'i2t':\n",
        "            return i2t_nnet(x, t, **kwargs)\n",
        "        elif mode == 'i':\n",
        "            return i_nnet(x, t)\n",
        "        return t_nnet(x, t)\n",
        "\n",
        "    dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)\n",
        "    with torch.no_grad():\n",
        "        with torch.autocast(device_type=device):\n",
        "            x = dpm_solver.sample(_x_init, steps=steps, eps=1. / N, T=1.)\n",
        "\n",
        "    os.makedirs(output_path, exist_ok=True)\n",
        "    if mode == 'joint':\n",
        "        _z, _clip_img, _text = split_joint(x)\n",
        "        return _z, _clip_img, _text\n",
        "    elif mode in ['t2i', 'i']:\n",
        "        _z, _clip_img = split(x)\n",
        "        return _z, _clip_img\n",
        "    return x\n"
      ],
      "metadata": {
        "id": "PQFyP5Vu9Ub5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Sample"
      ],
      "metadata": {
        "id": "SV0uxXQ6CEoZ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Seed the random number generators before any sampling in this cell.\n",
        "set_seed(seed)\n",
        "\n",
        "def show(path):\n",
        "  \"\"\"Open the image at `path` with PIL and render it in the notebook output.\"\"\"\n",
        "  display(Image.open(path))\n",
        "\n",
        "def _save_and_show_images(samples, mode):\n",
        "    \"\"\"Save each image and a grid of all images under `output_path/mode`, then display the grid.\n",
        "\n",
        "    Every image is written to '<idx>.png' and watermarked in place; the grid is built from\n",
        "    watermarked copies, written to 'grid.png' and shown in the notebook.\n",
        "    \"\"\"\n",
        "    out_dir = os.path.join(output_path, mode)\n",
        "    os.makedirs(out_dir, exist_ok=True)\n",
        "    for idx, sample in enumerate(samples):\n",
        "        save_path = os.path.join(out_dir, f'{idx}.png')\n",
        "        save_image(sample, save_path)\n",
        "        watermarking(save_path)\n",
        "    # save a grid of generated images\n",
        "    to_pil = standard_transforms.ToPILImage()\n",
        "    to_tensor = standard_transforms.ToTensor()\n",
        "    samples_pos = [to_tensor(utils.add_water(to_pil(sample))) for sample in samples]\n",
        "    grid_path = os.path.join(out_dir, 'grid.png')\n",
        "    save_image(make_grid(samples_pos, nrow), grid_path)\n",
        "    show(grid_path)\n",
        "\n",
        "\n",
        "if mode in ['joint']:\n",
        "    _z, _clip_img, _text = sample_fn(mode)\n",
        "    samples = unpreprocess(decode(_z))\n",
        "    prompts = caption_decoder.generate_captions(_text)\n",
        "    os.makedirs(os.path.join(output_path, mode), exist_ok=True)\n",
        "    print(prompts)\n",
        "    with open(os.path.join(output_path, mode, 'prompts.txt'), 'w') as f:\n",
        "        print('\\n'.join(prompts), file=f)\n",
        "    _save_and_show_images(samples, mode)\n",
        "\n",
        "elif mode in ['t2i', 'i', 'i2t2i']:\n",
        "    if mode == 't2i':\n",
        "        _z, _clip_img = sample_fn(mode, text=contexts_low_dim)  # conditioned on the text embedding\n",
        "    elif mode == 'i':\n",
        "        _z, _clip_img = sample_fn(mode)\n",
        "    elif mode == 'i2t2i':\n",
        "        # Image variation: caption the input image first, then generate images from that caption.\n",
        "        _text = sample_fn('i2t', z=z_img, clip_img=clip_imgs)  # conditioned on the image embedding\n",
        "        _z, _clip_img = sample_fn('t2i', text=_text)\n",
        "    samples = unpreprocess(decode(_z))\n",
        "    _save_and_show_images(samples, mode)\n",
        "\n",
        "elif mode in ['i2t', 't', 't2i2t']:\n",
        "    if mode == 'i2t':\n",
        "        _text = sample_fn(mode, z=z_img, clip_img=clip_imgs)  # conditioned on the image embedding\n",
        "    elif mode == 't':\n",
        "        _text = sample_fn(mode)\n",
        "    elif mode == 't2i2t':\n",
        "        # Text variation: generate an image from the prompt first, then caption that image.\n",
        "        _z, _clip_img = sample_fn('t2i', text=contexts_low_dim)\n",
        "        _text = sample_fn('i2t', z=_z, clip_img=_clip_img)\n",
        "    samples = caption_decoder.generate_captions(_text)\n",
        "    print(samples)\n",
        "    logging.info(samples)\n",
        "    os.makedirs(os.path.join(output_path, mode), exist_ok=True)\n",
        "    with open(os.path.join(output_path, mode, f'{mode}.txt'), 'w') as f:\n",
        "        print('\\n'.join(samples), file=f)"
      ],
      "metadata": {
        "id": "ip4LuG_rCGqD"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}